#pragma once
//a tensor-like data structure to store data

#include "../common.hpp"
#include "../Utility/Log.hpp"
#include "../Utility/IOObject.hpp"
#include "Vector2.hpp"
#include "Vector3.hpp"
#include "Vector4.hpp"
#include "Vector.hpp"


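//Lower-case functions are STL compatible: at() takes a flat integer index and size() returns the total element count.
//Capitalized functions work with VectorBase: At() also accepts a VectorBase position and Size() returns a VectorBase of per-dimension sizes.
//operator[] accepts both a flat index and a VectorBase position.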

namespace zzz {
// N-dimensional, heap-allocated array of T.
// Elements live in one contiguous buffer; the last dimension has stride 1
// (see SetSize, which builds the per-dimension strides in subsizes).
// Lower-case members follow STL naming and work on flat indices, capitalized
// members work on VectorBase positions / shapes.
template<int N, typename T>
class ArrayBase {
public:
  // type definitions (STL-style)
  typedef T              value_type;
  typedef T*             iterator;
  typedef const T*       const_iterator;
  typedef T&             reference;
  typedef const T&       const_reference;
  typedef zuint          size_type;
  typedef size_type      difference_type;

protected:
  T *v;                              // owned element buffer, NULL while the array holds no elements
  VectorBase<N,size_type> sizes;     // extent of every dimension
  VectorBase<N,size_type> subsizes;  // stride (in elements) of every dimension
  size_type size_all;                // total number of elements

public:
  //c'tor and d'tor
  // Empty array: no storage, all extents zero.
  ArrayBase():v(NULL),sizes(0),subsizes(0),size_all(0){}
  // Allocates storage for the given shape; elements are default-initialized by new[].
  explicit ArrayBase(const VectorBase<N,size_type> &size):v(NULL),sizes(0),subsizes(0),size_all(0){SetSize(size);}
  // Allocates storage for the given shape and copies the elements from data.
  ArrayBase(const T *data, const VectorBase<N,size_type> &size):v(NULL),sizes(0),subsizes(0),size_all(0){SetSize(size); SetData(data);}
  // Deep copy of another array.
  ArrayBase(const ArrayBase<N, T> &other):v(NULL),sizes(0),subsizes(0),size_all(0){SetSize(other.sizes); SetData(other.v);}
  // delete[] on NULL is a no-op, so no explicit check is needed.
  virtual ~ArrayBase(){delete[] v; v=NULL;}

  //Direct pointer
  T* Data(){return v;}
  const T* Data() const {return v;}

  //mutator
  // Copies size_all elements from data into the buffer through IOObject<T>::CopyData.
  void SetData(const T *data){IOObject<T>::CopyData(v,data,size_all);}
  // Changes the shape. A call with the current shape is a no-op; otherwise the
  // old contents are discarded and fresh storage is allocated (none when the
  // total element count is zero).
  void SetSize(const VectorBase<N,size_type> &size) {
    if (sizes==size) return;
    if (v) {
      delete[] v;
      v = NULL;
    }
    sizes=size;
    // Build strides from the last dimension backwards.
    subsizes[N-1]=1;
    for (int i=N-2; i>=0; i--) subsizes[i]=subsizes[i+1]*sizes[i+1];
    size_all=subsizes[0]*sizes[0];
    if (size_all != 0) {
      v=new T[size_all];
      ZCHECK_NOT_NULL(v)<<"In ArrayBase, cannot allocate memory size of "<<sizeof(T)*size_all/1024.0/1024.0<<" MB"<<endl;
    }
  }
  // Reshape, then copy the elements from data.
  void Set(const T *data, const VectorBase<N, size_type> &size){SetSize(size); SetData(data);}
  // Reshape to match other, then copy its elements.
  void Set(const ArrayBase<N, T> &other){SetSize(other.sizes); SetData(other.v);}
  // Deep-copy assignment; self-assignment is skipped so the buffer is never copied onto itself.
  const ArrayBase<N, T> &operator=(const ArrayBase<N, T> &other){
    if (this!=&other) Set(other);
    return *this;
  }
  // Assigns the same value to every element.
  const ArrayBase<N, T> &operator=(const T &other){for (size_type i=0; i<size_all; i++) v[i]=other; return *this;}
  // Byte-wise copy of length bytes into the buffer; length must not exceed the buffer size in bytes.
  void RawCopy(const void *data, size_type length){ZRCHECK_LE(length, sizeof(T)*size_all); memcpy(v, data, length);}
  // Reshape to match other, then copy its buffer byte-wise.
  void RawCopy(const ArrayBase<N, T> &other){
    if (this==&other) return;  // memcpy must not be given overlapping ranges
    SetSize(other.sizes);
    if (size_all!=0) memcpy(v,other.v,sizeof(T)*size_all);
  }

  //reference
  // Flat-index access, checked with ZRCHECK_LT.
  const T& At(const size_type pos) const{
    ZRCHECK_LT(pos, size_all);
    return v[pos];
  }
  T& At(const size_type pos) {
    ZRCHECK_LT(pos, size_all);
    return v[pos];
  }
  // Position access: the position is converted to a flat index first.
  const T& At(const VectorBase<N,size_type> &pos) const{return At(ToIndex(pos));}
  T& At(const VectorBase<N,size_type> &pos) {return At(ToIndex(pos));}
  const T& operator[](const size_type pos) const{return At(pos);}
  T& operator[](const size_type pos) {return At(pos);}
  const T& operator[](const VectorBase<N,size_type> &pos) const {return At(pos);}
  T& operator[](const VectorBase<N,size_type> &pos) {return At(pos);}
  const T& operator()(const VectorBase<N,size_type> &pos) const {return At(pos);}
  T& operator()(const VectorBase<N,size_type> &pos) {return At(pos);}

  //STL support
  // iterator support (plain pointers over the flat buffer)
  iterator begin() { return v; }
  const_iterator begin() const { return v; }
  iterator end() { return v+size_all; }
  const_iterator end() const { return v+size_all; }

  // reverse iterator support
  typedef std::reverse_iterator<iterator> reverse_iterator;
  typedef std::reverse_iterator<const_iterator> const_reverse_iterator;

  reverse_iterator rbegin() { return reverse_iterator(end()); }
  const_reverse_iterator rbegin() const { return const_reverse_iterator(end());}
  reverse_iterator rend() { return reverse_iterator(begin()); }
  const_reverse_iterator rend() const {return const_reverse_iterator(begin());}

  // at() with range check
  reference at(size_type i) { return At(i); }
  const_reference at(size_type i) const { return At(i); }

  // front() and back(); the array must not be empty
  reference front() { return v[0]; }
  const_reference front() const { return v[0];}
  reference back() { return v[size_all-1]; }
  const_reference back() const { return v[size_all-1]; }

  // size() is the total element count, Size(i) the extent of dimension i,
  // Size() the whole shape.
  size_type size() const { return size_all; }
  size_type Size(size_type i) const { return sizes[i]; }
  VectorBase<N,size_type> Size() const { return sizes; }
  bool empty() const { return v == NULL; }
  size_type max_size() const { return size_all; }
  void clear() { this->Clear(); }
  // Releases the storage by resizing to an all-zero shape.
  void Clear() { SetSize(VectorBase<N,size_type>(0)); }

  // index
  // Position -> flat index, using the strides in subsizes.
  size_type ToIndex(const VectorBase<N,size_type> &pos) const {return pos.Dot(subsizes);}
  // Flat index -> position: peel off one dimension at a time, largest stride first.
  VectorBase<N,size_type> ToIndex(size_type idx) const {
    VectorBase<N,size_type> pos;
    for (size_type i=0; i<N; i++) {
      pos[i]=idx/subsizes[i];
      idx-=pos[i]*subsizes[i];
    }
    return pos;
  }
  // True when every component of pos is inside the corresponding extent.
  bool CheckIndex(const VectorBase<N,size_type> &pos) const {for (size_type i=0; i<N; i++) if (pos[i]>=sizes[i]) return false; return true;}
  bool CheckIndex(const size_type &pos) const {return pos<size_all;}

  // Element-wise arithmetic. The binary forms return a new array that has the
  // shape of *this (the result must be sized before it is written to); the
  // array-array forms require both operands to have the same shape.

  //addition
  inline const ArrayBase<N, T> operator+(const ArrayBase<N, T> &a) const {
    ZCHECK_EQ(sizes, a.sizes);
    ArrayBase<N, T> ret(sizes);
    for (size_type i=0; i<size_all; i++)  ret.v[i]=v[i]+a.v[i];
    return ret;
  }
  inline const ArrayBase<N, T> operator+(const T &a) const {
    ArrayBase<N, T> ret(sizes);
    for (size_type i=0; i<size_all; i++)  ret.v[i]=v[i]+a;
    return ret;
  }
  inline void operator+=(const ArrayBase<N, T> &a) {
    ZCHECK_EQ(sizes, a.sizes);
    for (size_type i=0; i<size_all; i++)  v[i]+=a.v[i];
  }
  inline void operator+=(const T &a) {
    for (size_type i=0; i<size_all; i++)  v[i]+=a;
  }

  //subtraction
  inline const ArrayBase<N, T> operator-(const ArrayBase<N, T> &a) const {
    ZCHECK_EQ(sizes, a.sizes);
    ArrayBase<N, T> ret(sizes);
    for (size_type i=0; i<size_all; i++)  ret.v[i]=v[i]-a.v[i];
    return ret;
  }
  inline const ArrayBase<N, T> operator-(const T &a) const {
    ArrayBase<N, T> ret(sizes);
    for (size_type i=0; i<size_all; i++)  ret.v[i]=v[i]-a;
    return ret;
  }
  inline void operator-=(const ArrayBase<N, T> &a) {
    ZCHECK_EQ(sizes, a.sizes);
    for (size_type i=0; i<size_all; i++)  v[i]-=a.v[i];
  }
  inline void operator-=(const T &a) {
    for (size_type i=0; i<size_all; i++)  v[i]-=a;
  }

  //multiplication
  inline const ArrayBase<N, T> operator*(const ArrayBase<N, T> &a) const {
    ZCHECK_EQ(sizes, a.sizes);
    ArrayBase<N, T> ret(sizes);
    for (size_type i=0; i<size_all; i++)  ret.v[i]=v[i]*a.v[i];
    return ret;
  }
  inline const ArrayBase<N, T> operator*(const T &a) const {
    ArrayBase<N, T> ret(sizes);
    for (size_type i=0; i<size_all; i++)  ret.v[i]=v[i]*a;
    return ret;
  }
  inline void operator*=(const ArrayBase<N, T> &a) {
    ZCHECK_EQ(sizes, a.sizes);
    for (size_type i=0; i<size_all; i++)  v[i]*=a.v[i];
  }
  inline void operator*=(const T &a) {
    for (size_type i=0; i<size_all; i++)  v[i]*=a;
  }

  //division
  inline const ArrayBase<N, T> operator/(const ArrayBase<N, T> &a) const {
    ZCHECK_EQ(sizes, a.sizes);
    ArrayBase<N, T> ret(sizes);
    for (size_type i=0; i<size_all; i++)  ret.v[i]=v[i]/a.v[i];
    return ret;
  }
  inline const ArrayBase<N, T> operator/(const T &a) const {
    ArrayBase<N, T> ret(sizes);
    for (size_type i=0; i<size_all; i++)  ret.v[i]=v[i]/a;
    return ret;
  }
  inline void operator/=(const ArrayBase<N, T> &a) {
    ZCHECK_EQ(sizes, a.sizes);
    for (size_type i=0; i<size_all; i++)  v[i]/=a.v[i];
  }
  inline void operator/=(const T &a) {
    for (size_type i=0; i<size_all; i++)  v[i]/=a;
  }

  //Min and Max
  // All reductions below read v[0] first, so the array must not be empty.
  // The *Pos variants return the flat index of the first extreme element.
  inline T Max() const {
    T maxv=v[0];
    for (size_type i=1; i<size_all; i++)  if (maxv<v[i]) maxv=v[i];
    return maxv;
  }
  inline int MaxPos() const {
    T maxv=v[0];
    int pos=0;
    for (size_type i=1; i<size_all; i++) if (maxv<v[i]) {maxv=v[i];pos=static_cast<int>(i);}
    return pos;
  }
  inline T Min() const {
    T minv=v[0];
    for (size_type i=1; i<size_all; i++) if (minv>v[i]) minv=v[i];
    return minv;
  }
  inline int MinPos() const {
    T minv=v[0];
    int pos=0;
    for (size_type i=1; i<size_all; i++) if (minv>v[i]) {minv=v[i];pos=static_cast<int>(i);}
    return pos;
  }
  // Largest absolute value. The running extreme is kept as an absolute value
  // too; storing the signed element would corrupt later comparisons.
  inline T AbsMax() const {
    T maxv=abs(v[0]);
    for (size_type i=1; i<size_all; i++) {
      const T cur=abs(v[i]);
      if (maxv<cur) maxv=cur;
    }
    return maxv;
  }
  // Flat index of the element with the largest absolute value.
  inline int AbsMaxPos() const {
    T maxv=abs(v[0]);
    int pos=0;
    for (size_type i=1; i<size_all; i++) {
      const T cur=abs(v[i]);
      if (maxv<cur) {maxv=cur;pos=static_cast<int>(i);}
    }
    return pos;
  }
  // Smallest absolute value.
  inline T AbsMin() const {
    T minv=abs(v[0]);
    for (size_type i=1; i<size_all; i++) {
      const T cur=abs(v[i]);
      if (minv>cur) minv=cur;
    }
    return minv;
  }
  // Flat index of the element with the smallest absolute value.
  inline int AbsMinPos() const {
    T minv=abs(v[0]);
    int pos=0;
    for (size_type i=1; i<size_all; i++) {
      const T cur=abs(v[i]);
      if (minv>cur) {minv=cur;pos=static_cast<int>(i);}
    }
    return pos;
  }

  // math
  // Sum of element-wise products; both arrays must have the same shape.
  inline T Dot(const ArrayBase<N, T> &a) const {
    ZCHECK_EQ(sizes, a.sizes);
    T ret=0;
    for (size_type i=0; i<size_all; i++) ret+=v[i]*a.v[i];
    return ret;
  }
  inline T Sum() const{T ret=0;for (size_type i=0; i<size_all; i++) ret+=v[i];return ret;}
  // Negates every element in place.
  inline void Negative() {for (size_type i=0; i<size_all; i++) v[i]=-v[i];}
  // Euclidean length of the flattened array (computed in double, cast back to T).
  inline T Len() const {return (T)sqrt((double)LenSqr());}
  inline T LenSqr() const {return Dot(*this);}

  // Divides every element by Len() and returns the length used; no guard against a zero length.
  inline T Normalize() {T norm=Len(); *this/=norm; return norm;}
  inline ArrayBase<N, T> Normalized() const {return *this/Len();}

  inline T DistTo(const ArrayBase<N, T> &a) const {return (T)sqrt((double)DistToSqr(a));}
  inline T DistToSqr(const ArrayBase<N, T> &a) const {ArrayBase<N, T> diff(*this-a); return diff.Dot(diff);}
  // Element-wise absolute value as a new array of the same shape.
  inline ArrayBase<N, T> Abs() const {
    ArrayBase<N, T> ret(sizes);
    for (size_type i=0; i<size_all; i++) ret.v[i]=abs(v[i]);
    return ret;
  }

  // comparisons (on the flattened element sequence)
  // Equality needs the same element count first; otherwise std::equal would
  // read past the end of the shorter array.
  bool operator== (const ArrayBase<N, T>& y) const {return size_all==y.size_all && std::equal(begin(), end(), y.begin());}
  bool operator< (const ArrayBase<N, T>& y) const {return std::lexicographical_compare(begin(),end(), y.begin(), y.end());}
  bool operator!= (const ArrayBase<N, T>& y) const {return !(*this==y);}
  bool operator> (const ArrayBase<N, T>& y) const {return y<*this;}
  bool operator<= (const ArrayBase<N, T>& y) const {return !(y<*this);}
  bool operator>= (const ArrayBase<N, T>& y) const {return !(*this<y);}

  // Sets every byte of the buffer to x.
  void Zero(zuchar x=0){memset(v, x, sizeof(T)*size_all);}
  // Assigns a to every element.
  void Fill(const T &a) { for (size_type i=0; i<size_all; i++) v[i]=a; }
  // Reshape, then assign a to every element.
  void Fill(const VectorBase<N,size_type> &size, const T &a) {
    SetSize(size);
    Fill(a);
  }

  //IO
  // Text output: every element followed by a single space (the shape is not written).
  friend inline ostream& operator<<(ostream& os,const ArrayBase<N, T> &me) {
    for (size_type i=0; i<me.size(); i++)
      os<<me.v[i]<<' ';
    return os;
  }
  // Text input into an array that already has its final shape.
  friend inline istream& operator>>(istream& is,ArrayBase<N, T> &me) {
    for (size_type i=0; i<me.size(); i++)
      is>>me.v[i];
    return is;
  }
  // Text save: shape first, then every element followed by a space.
  void SaveToFileA(ostream &fo) const {
    fo<<sizes;
    for (size_type i=0; i<size_all; i++)
      fo<<v[i]<<' ';
  }
  // Text load: reads the shape, resizes, then reads every element.
  void LoadFromFileA(istream &fi) {
    VectorBase<N,size_type> s;
    fi>>s;
    SetSize(s);
    for (size_type i=0; i<size_all; i++)
      fi>>v[i];
  }
  // Binary save: shape followed by the raw element buffer.
  void WriteFileB(FILE *fp) const {
    sizes.WriteFileB(fp);
    fwrite(v,sizeof(T),size_all,fp);
  }
  // Binary load: reads the shape, resizes, then reads the raw element buffer.
  // NOTE(review): the fread result is not checked, so a short read goes unnoticed.
  void ReadFileB(FILE *fp) {
    VectorBase<N,size_type> s;
    s.ReadFileB(fp);
    SetSize(s);
    fread(v,sizeof(T),size_all,fp);
  }

  // Multilinear interpolation at a fractional position.
  // Returns T(0) when the 2^N cell around coord is not completely inside the
  // array (negative or NaN coordinate, or floor(coord[i])+1 beyond the extent).
  T Interpolate(const VectorBase<N,double> &coord) const
  {
    Vector<N,double> ratio;       // weight of the lower corner along each dimension
    Vector<N, size_type> lower;   // lower corner of the enclosing cell
    for (int i=0; i<N; i++)
    {
      if (!(coord[i]>=0)) return T(0);  // written this way so NaN is rejected too
      const double base=floor(coord[i]);
      // Compare in double: avoids the unsigned underflow of Size(i)-1 for an
      // empty dimension and the conversion of an out-of-range value.
      if (base+1>=static_cast<double>(Size(i))) return T(0);
      const size_type tmp=static_cast<size_type>(base);
      lower[i]=tmp;
      ratio[i]=1.0-coord[i]+tmp;
    }

    //rip out the cell: copy the 2^N corner values into a small array
    ArrayBase<N, T> data(Vector<N,size_type>(2));
    {
      VectorBase<N,size_type> thiscoord(0);
      bool good=true;
      while(good)
      {
        data(thiscoord)=At(thiscoord+lower);
        //next coord
        thiscoord[N-1]++;
        //normalize: carry into the next dimension when a digit reaches 2
        int curdim=N-1;
        while(true)
        {
          if (thiscoord[curdim]<2) break; //good
          if (curdim==0)  //overflow
          {
            good=false;
            break;
          }
          thiscoord[curdim]=0;
          curdim--;
          thiscoord[curdim]++;
        }
      }
    }

    // Collapse one dimension at a time: after pass i every entry whose first
    // i+1 coordinates are 0 holds the value interpolated along dimensions 0..i.
    for (int i=0; i<N; i++)
    {
      VectorBase<N,size_type> thiscoord(0);
      bool good=true;
      while(good)
      {
        //interpolate
        VectorBase<N,size_type> another(thiscoord);
        another[i]=1;
        data(thiscoord)=data(thiscoord)*ratio[i]+data(another)*(1.0-ratio[i]);
        //next coord
        thiscoord[N-1]++;
        //normalize: only dimensions after i are enumerated
        int curdim=N-1;
        while(true)
        {
          if (curdim==i)  //overflow
          {
            good=false;
            break;
          }
          if (thiscoord[curdim]<2) break; //good
          thiscoord[curdim]=0;
          curdim--;
          thiscoord[curdim]++;
        }
      }
    }
    return data(VectorBase<N,size_type>(0));
  }
};

//All of this boilerplate exists only to avoid a direct memory copy on construction and assignment
// Thin wrapper over ArrayBase<N, T> that re-exposes its constructors and
// assignment operators.
template <unsigned int N, typename T>
class Array: public ArrayBase<N, T>
{
public:
  // size_type lives in a dependent base class, so it is not found by
  // unqualified lookup inside this template and must be named explicitly.
  typedef typename ArrayBase<N, T>::size_type size_type;

  //constructor
  Array(void){}
  // Allocates storage for the given shape.
  explicit Array(const VectorBase<N,size_type> &size):ArrayBase<N, T>(size){}
  // Allocates storage for the given shape and copies the elements from data.
  Array(const T *data, const VectorBase<N,size_type> &size):ArrayBase<N, T>(data, size){}
  // Deep copies.
  Array(const Array<N, T>& a):ArrayBase<N, T>(a) {}
  Array(const ArrayBase<N, T>& a):ArrayBase<N, T>(a) {}

  //assign
  // Deep-copy assignment, delegated to the base class.
  inline const Array<N, T> &operator=(const Array<N, T>& a){ArrayBase<N, T>::operator =(a); return *this;}
  // Keep the base-class assignment overloads (from ArrayBase and from a scalar) visible.
  using ArrayBase<N, T>::operator=;
};

// IOObject

// Serialization of ArrayBase<N, T>.
// The non-type parameter is declared as int to match ArrayBase's own
// template<int N, typename T>; with a different type the argument cannot be
// deduced and this partial specialization would never be selected by a
// conforming compiler.
// Shapes are stored on disk as 64-bit unsigned values and converted back to
// zuint when read.
template<int N, typename T>
class IOObject<ArrayBase<N, T> >
{
public:
  // Binary write: shape, then the element buffer as one raw array.
  static void WriteFileB(FILE *fp, const ArrayBase<N, T> &src) {
    Vector<N, zuint64> size(src.Size());
    IOObj::WriteFileB(fp, size);
    IOObj::WriteFileB(fp, src.Data(), src.size());
  }
  // Binary read: shape, resize, then the element buffer as one raw array.
  static void ReadFileB(FILE *fp, ArrayBase<N, T>& dst) {
    Vector<N, zuint64> s;
    IOObj::ReadFileB(fp, s);
    dst.SetSize(Vector<N, zuint>(s));
    IOObj::ReadFileB(fp, dst.Data(), dst.size());
  }
  // Record-file write: the shape under label, followed by the element buffer.
  static void WriteFileR(RecordFile &fp, const zint32 label, const ArrayBase<N, T>& src) {
    Vector<N, zuint64> size(src.Size());
    IOObj::WriteFileR(fp, label, size);
    IOObj::WriteFileR(fp, src.Data(), src.size());
  }
  // Record-file read: the shape under label, resize, then the element buffer.
  static void ReadFileR(RecordFile &fp, const zint32 label, ArrayBase<N, T>& dst) {
    Vector<N, zuint64> size;
    IOObj::ReadFileR(fp, label, size);
    dst.SetSize(Vector<N, zuint>(size));
    IOObj::ReadFileR(fp, dst.Data(), dst.size());
  }
  /// When the object inside Array is not SIMPLE_IOOBJECT, it must access 1 by 1,
  /// instead of access as a raw array.
  static const int RF_SIZE = 1;  // child label of the shape
  static const int RF_DATA = 2;  // child label of the repeated element records
  // Writes the shape and then every element as its own repeated child record.
  static void WriteFileR1By1(RecordFile &rf, const zint32 label, const ArrayBase<N, T>& src)
  {
    // Same widening conversion as in WriteFileB / WriteFileR.
    Vector<N, zuint64> len(src.Size());
    rf.WriteChildBegin(label);
    IOObj::WriteFileR(rf, RF_SIZE, len);
    rf.WriteRepeatBegin(RF_DATA);
    const zuint sizeall = src.size();
    for (zuint i = 0; i < sizeall; ++i) {
      rf.WriteRepeatChild();
      IOObj::WriteFileR(rf, src[i]);
    }
    rf.WriteRepeatEnd();
    rf.WriteChildEnd();
  }
  // Reads what WriteFileR1By1 wrote. dst is cleared first and stays empty
  // when label is not present in the file.
  static void ReadFileR1By1(RecordFile &rf, const zint32 label, ArrayBase<N, T>& dst)
  {
    dst.clear();
    if (!rf.LabelExist(label)) {
      return;
    }
    rf.ReadChildBegin(label);
    Vector<N, zuint64> len;
    IOObj::ReadFileR(rf, RF_SIZE, len);
    // Same narrowing conversion as in ReadFileB / ReadFileR.
    dst.SetSize(Vector<N, zuint>(len));
    rf.ReadRepeatBegin(RF_DATA);
    zuint i = 0;
    while(rf.ReadRepeatChild()) {
      T v;
      IOObj::ReadFileR(rf, v);
      dst[i++] = v;
    }
    rf.ReadRepeatEnd();
    // Compare the number of elements actually read with the element count
    // implied by the recorded shape (not the count with the shape vector).
    ZCHECK_EQ(i, dst.size()) << "The length recorded is different from the actual length of data";
    rf.ReadChildEnd();
  }

};

// Serialization of Array<N, T>: every call is forwarded unchanged to the
// IOObject specialization for its base class ArrayBase<N, T>.
template<zuint N, typename T>
class IOObject<Array<N, T> >
{
  // Shorthand for the specialization that does the actual work.
  typedef IOObject<ArrayBase<N, T> > BaseIO;

public:
  // Binary write.
  static void WriteFileB(FILE *fp, const Array<N, T> &src)
  {
    BaseIO::WriteFileB(fp, src);
  }

  // Binary read.
  static void ReadFileB(FILE *fp, Array<N, T>& dst)
  {
    BaseIO::ReadFileB(fp, dst);
  }

  // Record-file write under the given label.
  static void WriteFileR(RecordFile &fp, const zint32 label, const Array<N, T>& src)
  {
    BaseIO::WriteFileR(fp, label, src);
  }

  // Record-file read of the given label.
  static void ReadFileR(RecordFile &fp, const zint32 label, Array<N, T>& dst)
  {
    BaseIO::ReadFileR(fp, label, dst);
  }
};

// Element-wise absolute value of an array, returned as a new array of the
// same shape.
// N is declared as int so that it can be deduced from ArrayBase<N, T>, whose
// own non-type parameter is an int. The result is computed here directly, so
// the function works on a const argument.
template <int N, typename T>
inline ArrayBase<N, T> Abs(const ArrayBase<N, T> &x) {
  typedef typename ArrayBase<N, T>::size_type size_type;
  ArrayBase<N, T> ret(x.Size());
  const size_type count = x.size();
  for (size_type i = 0; i < count; i++) ret[i] = abs(x[i]);
  return ret;
}


}
